# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import os
import tempfile

import numpy as np
import pytest

import tvm
import tvm.testing
from tvm import relax, tir
from tvm.contrib import utils
from tvm.relax.dpl import is_op, wildcard
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T

# Look up the TensorRT codegen and runtime checkers. With allow_missing=True,
# get_global_func returns None instead of raising when TVM was built without
# TensorRT support.
env_checker_codegen = tvm.get_global_func("relax.ext.tensorrt", True)
env_checker_runtime = tvm.get_global_func("relax.is_tensorrt_runtime_enabled", True)

requires_tensorrt_codegen = pytest.mark.skipif(
    not env_checker_codegen,
    reason="TensorRT codegen not available",
)
requires_tensorrt_runtime = pytest.mark.skipif(
    not env_checker_runtime or not env_checker_runtime(),
    reason="TensorRT runtime not available",
)

# `pytestmark` is a module-level pytest variable; these markers are applied to
# every test in this file.
pytestmark = [requires_tensorrt_codegen] + tvm.testing.requires_cuda.marks()

# Compilation target and execution device
target_str = "nvidia/nvidia-t4"
target = tvm.target.Target(target_str)
dev = tvm.cuda()


def check_executable(ex, dev, inputs, expected, entry_func_name):
    # Run the compiled executable through the Relax VM and compare its output
    # against the reference result.
    vm = relax.VirtualMachine(ex, dev)
    out = vm[entry_func_name](*inputs)
    tvm.testing.assert_allclose(out.numpy(), expected.numpy(), atol=1e-5, rtol=1e-5)


def check_roundtrip(exec0, dev, inputs, expected, entry_func_name="main"):
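    # Export the executable as a shared library, load it back, and verify the
    # reloaded module matches the in-memory executable before running both.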
    with utils.tempdir() as temp:
        exec0.mod.export_library(temp.relpath("exec.so"))
        exec1 = tvm.runtime.load_module(temp.relpath("exec.so"))
    assert exec0.stats() == exec1["stats"]()
    assert exec0.as_text() == exec1["as_text"]()

    check_executable(exec0, dev, inputs, expected, entry_func_name)
    check_executable(exec1, dev, inputs, expected, entry_func_name)


def gen_ground_truth(mod, target, dev, inputs):
    # Legalize the graph-level ops to TIR and apply the default GPU schedule.
    # MetaSchedule has no default GPU schedule yet, so this step is required to
    # produce a runnable reference module.
    with target:
        seq = tvm.transform.Sequential(
            [relax.transform.LegalizeOps(), tir.transform.DefaultGPUSchedule()]
        )
        new_mod = seq(mod)
    assert relax.analysis.well_formed(new_mod)
    ex = tvm.compile(new_mod, target, params={})
    vm = relax.VirtualMachine(ex, dev)
    return vm["main"](*inputs)


@tvm.script.ir_module
class InputModule:
    @R.function
    def main(
        x: R.Tensor((16, 16), "float32"), y: R.Tensor((16, 16), "float32")
    ) -> R.Tensor((16, 16), "float32"):
        with R.dataflow():
            z1 = R.multiply(x, y)
            z2 = R.add(z1, x)
            z3 = R.add(z1, z2)
            z4 = R.multiply(z3, z2)
            z5 = R.add(z4, z1)
            R.output(z5)
        return z5


def setup_test():
    # Prepare IRModule and its input
    mod = InputModule
    assert isinstance(mod, tvm.IRModule)

    np0 = np.random.rand(16, 16).astype(np.float32)
    np1 = np.random.rand(16, 16).astype(np.float32)
    data0 = tvm.runtime.tensor(np0, dev)
    data1 = tvm.runtime.tensor(np1, dev)
    inputs = [data0, data1]

    # The ground truth must be generated before BYOC annotation because the
    # annotated functions conflict with MetaSchedule task extraction.
    # TODO(@sunggg): Sort this out
    expected = gen_ground_truth(mod, target, dev, inputs)
    return mod, inputs, expected


# Parametrize the tests below over the entry function name.
entry_func_name = tvm.testing.parameter("main", "func")


@tvm.testing.requires_gpu
@requires_tensorrt_runtime
def test_tensorrt_only(entry_func_name):
    mod, inputs, expected = setup_test()

    if entry_func_name != "main":
        mod[entry_func_name] = mod["main"].with_attr("global_symbol", entry_func_name)
        del mod["main"]

    # Define the patterns to offload to BYOC. This test offloads the entire
    # model, so patterns for both the `multiply` and `add` ops are required.
    # FuseOpsByPattern wraps each match in a composite function,
    # MergeCompositeFunctions groups adjacent composites into one offloaded
    # region, and RunCodegen compiles that region with TensorRT.
    patterns = [
        ("tensorrt.multiply", is_op("relax.multiply")(wildcard(), wildcard())),
        ("tensorrt.add", is_op("relax.add")(wildcard(), wildcard())),
    ]

    new_mod = tvm.transform.Sequential(
        [
            relax.transform.FuseOpsByPattern(patterns),
            relax.transform.MergeCompositeFunctions(),
            relax.transform.RunCodegen(),
        ]
    )(mod)

    ex0 = tvm.compile(new_mod, target, params={})
    # Sanity check for the correctness and roundtrip
    check_roundtrip(ex0, dev, inputs, expected, entry_func_name)


@tvm.testing.requires_gpu
@requires_tensorrt_runtime
def test_mix_use_tensorrt_and_tvm():
    mod, inputs, expected = setup_test()

    # Define the patterns to offload to BYOC. This test offloads only the
    # `add` op to TensorRT and tunes the `multiply` op with MetaSchedule.
    patterns = [
        ("tensorrt.add", is_op("relax.add")(wildcard(), wildcard())),
    ]

    # Run the BYOC pipeline, then legalize and tune the remaining ops.
    with tempfile.TemporaryDirectory() as work_dir:
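        # MetaSchedule writes its tuning records into work_dir and reads them
        # back when applying the database.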
        with target, tvm.transform.PassContext(opt_level=0):
            new_mod = tvm.transform.Sequential(
                [
                    relax.transform.FuseOpsByPattern(patterns),
                    relax.transform.MergeCompositeFunctions(),
                    relax.transform.RunCodegen(),
                    relax.transform.LegalizeOps(),
                    relax.transform.MetaScheduleTuneIRMod(
                        params={}, work_dir=work_dir, max_trials_global=8
                    ),
                    relax.transform.MetaScheduleApplyDatabase(work_dir),
                ]
            )(mod)
    assert relax.analysis.well_formed(new_mod)
    with tvm.transform.PassContext(opt_level=0):
        ex0 = tvm.compile(new_mod, target, params={})

    # Sanity check for the correctness and roundtrip
    check_roundtrip(ex0, dev, inputs, expected)


@tvm.script.ir_module
class Conv2dx2:
    @R.function
    def main(
        data: R.Tensor((16, 32, 32, 16), dtype="float16"),
        weight1: R.Tensor((16, 3, 3, 16), dtype="float16"),
        weight2: R.Tensor((16, 3, 3, 16), dtype="float16"),
    ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
        cls = Conv2dx2
        with R.dataflow():
            lv: R.Tensor((16, 32, 32, 16), dtype="float16") = cls.fused_relax_nn_conv2d_tensorrt(
                data, weight1
            )
            gv: R.Tensor((16, 32, 32, 16), dtype="float16") = cls.fused_relax_nn_conv2d_tensorrt(
                lv, weight2
            )
            R.output(gv)
        return gv

    @R.function
    def fused_relax_nn_conv2d_tensorrt(
        data: R.Tensor((16, 32, 32, 16), dtype="float16"),
        weight1: R.Tensor((16, 3, 3, 16), dtype="float16"),
    ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
        R.func_attr({"Codegen": "tensorrt", "global_symbol": "fused_relax_nn_conv2d_tensorrt"})
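        # The outer function's "Codegen" attribute tells RunCodegen to compile
        # it with the named external backend; the nested function's "Composite"
        # attribute records which pattern it implements.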

        @R.function
        def gv(
            data_1: R.Tensor((16, 32, 32, 16), dtype="float16"),
            weight1_1: R.Tensor((16, 3, 3, 16), dtype="float16"),
        ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
            R.func_attr({"Composite": "tensorrt.conv2d", "Primitive": True})
            with R.dataflow():
                gv_1: R.Tensor((16, 32, 32, 16), dtype="float16") = R.nn.conv2d(
                    data_1,
                    weight1_1,
                    padding=[1, 1, 1, 1],
                    data_layout="NHWC",
                    kernel_layout="OHWI",
                    out_layout="NHWC",
                )
                R.output(gv_1)
            return gv_1

        gv1: R.Tensor((16, 32, 32, 16), dtype="float16") = gv(data, weight1)
        return gv1


@tvm.script.ir_module
class Conv2dx2_after:
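    # Expected output of RunCodegen on Conv2dx2: each call to the offloaded
    # function is rewritten to R.call_dps_packed on the generated extern
    # symbol, and the offloaded Relax function itself is removed.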
    @R.function
    def main(
        data: R.Tensor((16, 32, 32, 16), dtype="float16"),
        weight1: R.Tensor((16, 3, 3, 16), dtype="float16"),
        weight2: R.Tensor((16, 3, 3, 16), dtype="float16"),
    ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
        with R.dataflow():
            lv = R.call_dps_packed(
                "fused_relax_nn_conv2d_tensorrt",
                (data, weight1),
                out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"),
            )
            gv = R.call_dps_packed(
                "fused_relax_nn_conv2d_tensorrt",
                (lv, weight2),
                out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"),
            )
            R.output(gv)
        return gv


def test_multiple_calls_same_extern():
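    # Both conv2d calls reference the same extern symbol, so RunCodegen should
    # generate the TensorRT module once and reuse it for both call sites.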
    mod = relax.transform.RunCodegen()(Conv2dx2)
    tvm.ir.assert_structural_equal(mod["main"], Conv2dx2_after["main"])


def test_default_entry_func():
    """The entry function is not necessarily named "main"

    Like `test_multiple_calls_same_extern`, but the main function is
    named "func".
    """
    before_with_main = Conv2dx2
    after_with_main = relax.transform.RunCodegen()(before_with_main)

    def rename_main(mod):
        mod = mod.clone()
        mod["func"] = mod["main"].with_attr("global_symbol", "func")
        del mod["main"]
        return mod

    before_with_func = rename_main(before_with_main)
    expected_with_func = rename_main(after_with_main)
    after_with_func = relax.transform.RunCodegen()(before_with_func)

    tvm.ir.assert_structural_equal(expected_with_func["func"], after_with_func["func"])


def test_dynamic_shape():
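    # Importing the cuBLAS backend module triggers its BYOC registration so
    # that RunCodegen can handle functions annotated with {"Codegen": "cublas"}.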
    import tvm.relax.backend.cuda.cublas

    @I.ir_module
    class Before:
        @R.function
        def main(
            x: R.Tensor((1, 4096), dtype="float16"),
            w1: R.Tensor((4096, "r1"), dtype="float16"),
            w2: R.Tensor((4096, "r2"), dtype="float16"),
        ) -> R.Tuple(R.Tensor((1, "r1"), dtype="float16"), R.Tensor((1, "r2"), dtype="float16")):
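            # r1 and r2 are symbolic dimensions bound from the parameter shapes.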
            r1 = T.int64()
            r2 = T.int64()
            cls = Before
            with R.dataflow():
                lv: R.Tensor((1, r1), dtype="float16") = cls.fused_relax_matmul_cublas(x, w1)
                lv1: R.Tensor((1, r2), dtype="float16") = cls.fused_relax_matmul_cublas(x, w2)
                gv: R.Tuple(
                    R.Tensor((1, r1), dtype="float16"), R.Tensor((1, r2), dtype="float16")
                ) = (lv, lv1)
                R.output(gv)
            return gv

        @R.function
        def fused_relax_matmul_cublas(
            x: R.Tensor((1, 4096), dtype="float16"), w1: R.Tensor((4096, "r1"), dtype="float16")
        ) -> R.Tensor((1, "r1"), dtype="float16"):
            r1 = T.int64()
            R.func_attr({"Codegen": "cublas"})

            @R.function
            def gv(
                x_1: R.Tensor((1, 4096), dtype="float16"),
                w1_1: R.Tensor((4096, r1), dtype="float16"),
            ) -> R.Tensor((1, r1), dtype="float16"):
                R.func_attr({"Composite": "cublas.matmul"})
                with R.dataflow():
                    gv_1: R.Tensor((1, r1), dtype="float16") = R.matmul(x_1, w1_1, out_dtype="void")
                    R.output(gv_1)
                return gv_1

            gv1: R.Tensor((1, r1), dtype="float16") = gv(x, w1)
            return gv1

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((1, 4096), dtype="float16"),
            w1: R.Tensor((4096, "r1"), dtype="float16"),
            w2: R.Tensor((4096, "r2"), dtype="float16"),
        ) -> R.Tuple(R.Tensor((1, "r1"), dtype="float16"), R.Tensor((1, "r2"), dtype="float16")):
            r1 = T.int64()
            r2 = T.int64()
            with R.dataflow():
                lv = R.call_dps_packed(
                    "fused_relax_matmul_cublas",
                    (x, w1),
                    out_sinfo=R.Tensor((1, r1), dtype="float16"),
                )
                lv1 = R.call_dps_packed(
                    "fused_relax_matmul_cublas",
                    (x, w2),
                    out_sinfo=R.Tensor((1, r2), dtype="float16"),
                )
                gv: R.Tuple(
                    R.Tensor((1, r1), dtype="float16"), R.Tensor((1, r2), dtype="float16")
                ) = (lv, lv1)
                R.output(gv)
            return gv

    after = relax.transform.RunCodegen()(Before)
    tvm.ir.assert_structural_equal(after["main"], Expected["main"])


def test_no_op_for_call_to_tir():
    """Calls to PrimFunc are ignored

    RunCodegen should only update calls to Relax functions annotated
    with the `"Codegen"` attribute.  Calls to any other function type
    should be ignored.

    This is a regression test.  Previous implementations performed an
    unconditional cast from `tvm::BaseFunc` to `tvm::relax::Function`,
    which produced an error.
    """

    @tvm.script.ir_module
    class Before:
        @R.function
        def main(x: R.Tensor([4], "int64")):
            R.func_attr({"relax.force_pure": True})
            _ = Before.shape_func(x)
            return x

        @T.prim_func(private=True)
        def shape_func(H: T.Buffer(T.int64(4), "int64")):
            H[T.int64(0)] = H[T.int64(0)] + T.int64(1)

    Expected = Before
    After = relax.transform.RunCodegen()(Before)
    tvm.ir.assert_structural_equal(Expected, After)


# TODO(@sunggg): test with more complex patterns (e.g., multiple annotations, mixed codegens, different ops, const bindings)

if __name__ == "__main__":
    pytest.main([__file__])
